{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp models.dilated_rnn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "#| hide\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dilated RNN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Dilated Recurrent Neural Network (`DilatedRNN`) addresses common challenges of modeling long sequences, such as vanishing gradients and computational inefficiency, while improving the model's flexibility to capture complex relationships and maintaining its parsimony. The `DilatedRNN` builds a deep stack of RNN layers using skip connections on the temporal and the network's depth dimensions. The temporal dilated recurrent skip connections offer the capability to focus on multi-resolution inputs. The predictions are obtained by transforming the hidden states into contexts $\\mathbf{c}_{[t+1:t+H]}$, that are decoded and adapted into $\\mathbf{\\hat{y}}_{[t+1:t+H],[q]}$ through MLPs.\n",
    "\n",
    "\\begin{align}\n",
    " \\mathbf{h}_{t} &= \\textrm{DilatedRNN}([\\mathbf{y}_{t},\\mathbf{x}^{(h)}_{t},\\mathbf{x}^{(s)}], \\mathbf{h}_{t-1})\\\\\n",
    "\\mathbf{c}_{[t+1:t+H]}&=\\textrm{Linear}([\\mathbf{h}_{t}, \\mathbf{x}^{(f)}_{[:t+H]}]) \\\\ \n",
    "\\hat{y}_{\\tau,[q]}&=\\textrm{MLP}([\\mathbf{c}_{\\tau},\\mathbf{x}^{(f)}_{\\tau}])\n",
    "\\end{align}\n",
    "\n",
    "where $\\mathbf{h}_{t}$ is the hidden state for time $t$, $\\mathbf{y}_{t}$ is the input at time $t$, $\\mathbf{h}_{t-1}$ is the hidden state of the previous layer at $t-1$, $\\mathbf{x}^{(s)}$ are static exogenous inputs, $\\mathbf{x}^{(h)}_{t}$ are historic exogenous inputs, and $\\mathbf{x}^{(f)}_{[:t+H]}$ are future exogenous inputs available at the time of the prediction.\n",
    "\n",
    "**References**<br>- [Shiyu Chang, et al. \"Dilated Recurrent Neural Networks\".](https://arxiv.org/abs/1710.02224)<br>- [Yao Qin, et al. \"A Dual-Stage Attention-Based recurrent neural network for time series prediction\".](https://arxiv.org/abs/1704.02971)<br>- [Kashif Rasul, et al. \"Zalando Research: PyTorch Dilated Recurrent Neural Networks\".](https://github.com/zalandoresearch/pytorch-dilated-rnn)<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Figure 1. Three layer DilatedRNN with dilation 1, 2, 4.](imgs_models/dilated_rnn.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import logging\n",
    "import warnings\n",
    "from nbdev.showdoc import show_doc\n",
    "from neuralforecast.utils import generate_series\n",
    "from neuralforecast.common._model_checks import check_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from typing import List, Optional\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from neuralforecast.losses.pytorch import MAE\n",
    "from neuralforecast.common._base_model import BaseModel\n",
    "from neuralforecast.common._modules import MLP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "class LSTMCell(nn.Module):\n",
    "    \"\"\"LSTM cell that advances the recurrent state by a single time step.\n",
    "\n",
    "    The four gate projections (input, forget, cell, output) are stacked row-wise in\n",
    "    `weight_ih` / `weight_hh` and their biases; all are initialised with `torch.randn`.\n",
    "    The `dropout` value is only stored on the module; `forward` does not apply it.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, dropout=0.):\n",
    "        super().__init__()\n",
    "        stacked_rows = 4 * hidden_size  # one slice of rows per gate\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.weight_ih = nn.Parameter(torch.randn(stacked_rows, input_size))\n",
    "        self.weight_hh = nn.Parameter(torch.randn(stacked_rows, hidden_size))\n",
    "        self.bias_ih = nn.Parameter(torch.randn(stacked_rows))\n",
    "        self.bias_hh = nn.Parameter(torch.randn(stacked_rows))\n",
    "        self.dropout = dropout\n",
    "\n",
    "    def forward(self, inputs, hidden):\n",
    "        \"\"\"Return `(h_next, (h_next, c_next))` for one step given `hidden = (h, c)`.\"\"\"\n",
    "        # Drop a leading singleton dimension from each state tensor, if present.\n",
    "        prev_h = hidden[0].squeeze(0)\n",
    "        prev_c = hidden[1].squeeze(0)\n",
    "\n",
    "        input_term = torch.matmul(inputs, self.weight_ih.t()) + self.bias_ih\n",
    "        recurrent_term = torch.matmul(prev_h, self.weight_hh.t())\n",
    "        gates = input_term + recurrent_term + self.bias_hh\n",
    "\n",
    "        # Split the stacked projection into the four gates along the feature axis.\n",
    "        in_gate, forget_gate, cell_gate, out_gate = torch.chunk(gates, 4, dim=1)\n",
    "        in_gate = torch.sigmoid(in_gate)\n",
    "        forget_gate = torch.sigmoid(forget_gate)\n",
    "        cell_gate = torch.tanh(cell_gate)\n",
    "        out_gate = torch.sigmoid(out_gate)\n",
    "\n",
    "        next_c = (forget_gate * prev_c) + (in_gate * cell_gate)\n",
    "        next_h = out_gate * torch.tanh(next_c)\n",
    "        return next_h, (next_h, next_c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "class ResLSTMCell(nn.Module):\n",
    "    \"\"\"LSTM cell variant with a residual connection from the inputs to the output.\n",
    "\n",
    "    The input, forget and output gates are computed from the inputs, the previous\n",
    "    hidden state and the previous cell state; the candidate cell value is computed\n",
    "    from the previous hidden state only. `input_size` and `hidden_size` are registered\n",
    "    as one-element buffers. The `dropout` value is stored but not applied in `forward`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, dropout=0.):\n",
    "        super().__init__()\n",
    "        self.register_buffer('input_size', torch.Tensor([input_size]))\n",
    "        self.register_buffer('hidden_size', torch.Tensor([hidden_size]))\n",
    "\n",
    "        gate_rows = 3 * hidden_size  # input, forget and output gates stacked row-wise\n",
    "        self.weight_ii = nn.Parameter(torch.randn(gate_rows, input_size))\n",
    "        self.weight_ic = nn.Parameter(torch.randn(gate_rows, hidden_size))\n",
    "        self.weight_ih = nn.Parameter(torch.randn(gate_rows, hidden_size))\n",
    "        self.bias_ii = nn.Parameter(torch.randn(gate_rows))\n",
    "        self.bias_ic = nn.Parameter(torch.randn(gate_rows))\n",
    "        self.bias_ih = nn.Parameter(torch.randn(gate_rows))\n",
    "\n",
    "        # Candidate cell projection and the input projection used by the residual path.\n",
    "        self.weight_hh = nn.Parameter(torch.randn(1 * hidden_size, hidden_size))\n",
    "        self.bias_hh = nn.Parameter(torch.randn(1 * hidden_size))\n",
    "        self.weight_ir = nn.Parameter(torch.randn(hidden_size, input_size))\n",
    "        self.dropout = dropout\n",
    "\n",
    "    def forward(self, inputs, hidden):\n",
    "        \"\"\"Return `(h_next, (h_next, c_next))` for one step given `hidden = (h, c)`.\"\"\"\n",
    "        prev_h = hidden[0].squeeze(0)\n",
    "        prev_c = hidden[1].squeeze(0)\n",
    "\n",
    "        from_inputs = torch.matmul(inputs, self.weight_ii.t()) + self.bias_ii\n",
    "        from_hidden = torch.matmul(prev_h, self.weight_ih.t())\n",
    "        from_cell = torch.matmul(prev_c, self.weight_ic.t())\n",
    "        ifo_gates = from_inputs + from_hidden + self.bias_ih + from_cell + self.bias_ic\n",
    "        in_gate, forget_gate, out_gate = torch.chunk(ifo_gates, 3, dim=1)\n",
    "\n",
    "        candidate = torch.matmul(prev_h, self.weight_hh.t()) + self.bias_hh\n",
    "\n",
    "        in_gate = torch.sigmoid(in_gate)\n",
    "        forget_gate = torch.sigmoid(forget_gate)\n",
    "        candidate = torch.tanh(candidate)\n",
    "        out_gate = torch.sigmoid(out_gate)\n",
    "\n",
    "        next_c = (forget_gate * prev_c) + (in_gate * candidate)\n",
    "        squashed_c = torch.tanh(next_c)\n",
    "\n",
    "        # Identity skip when the sizes match, otherwise project the inputs with `weight_ir`.\n",
    "        if self.input_size == self.hidden_size:\n",
    "            skip = inputs\n",
    "        else:\n",
    "            skip = torch.matmul(inputs, self.weight_ir.t())\n",
    "        next_h = out_gate * (squashed_c + skip)\n",
    "        return next_h, (next_h, next_c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "class ResLSTMLayer(nn.Module):\n",
    "    \"\"\"Applies a `ResLSTMCell` to every slice along the first axis of the inputs.\"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, dropout=0.):\n",
    "        super().__init__()\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        # NOTE: the cell is always built with dropout=0.; the `dropout` argument is not forwarded.\n",
    "        self.cell = ResLSTMCell(input_size, hidden_size, dropout=0.)\n",
    "\n",
    "    def forward(self, inputs, hidden):\n",
    "        \"\"\"Return the stacked per-step outputs and the state left after the last step.\"\"\"\n",
    "        step_outputs = []\n",
    "        for step_input in inputs.unbind(0):\n",
    "            step_output, hidden = self.cell(step_input, hidden)\n",
    "            step_outputs.append(step_output)\n",
    "        return torch.stack(step_outputs), hidden"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "class AttentiveLSTMLayer(nn.Module):\n",
    "    \"\"\"LSTM layer whose per-step cell input is an attention-pooled view of the window.\n",
    "\n",
    "    At every step, attention scores are computed from the inputs concatenated with the\n",
    "    current hidden and cell states, normalised with a softmax over the first (time)\n",
    "    axis, and used to pool the inputs into a context that is fed to an `LSTMCell`.\n",
    "    The `dropout` value is stored but not applied in `forward`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, dropout=0.0):\n",
    "        super(AttentiveLSTMLayer, self).__init__()\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        attention_hsize = hidden_size\n",
    "        self.attention_hsize = attention_hsize\n",
    "\n",
    "        self.cell = LSTMCell(input_size, hidden_size)\n",
    "        # Scores one scalar per (step, sample) from [inputs, hidden state, cell state].\n",
    "        self.attn_layer = nn.Sequential(nn.Linear(2 * hidden_size + input_size, attention_hsize),\n",
    "                                        nn.Tanh(),\n",
    "                                        nn.Linear(attention_hsize, 1))\n",
    "        self.softmax = nn.Softmax(dim=0)\n",
    "        self.dropout = dropout\n",
    "\n",
    "    def forward(self, inputs, hidden):\n",
    "        \"\"\"Run the attentive recurrence over the window.\n",
    "\n",
    "        `inputs` is indexed as [step, sample, feature] and `hidden` is an `(h, c)` pair.\n",
    "        Returns the stacked per-step outputs and the state left after the last step.\n",
    "        \"\"\"\n",
    "        # Keep `inputs` as a single tensor: `torch.cat` and `permute` below require a\n",
    "        # tensor, so rebinding it to the tuple returned by `unbind` made this layer fail.\n",
    "        n_steps = inputs.size(0)\n",
    "        inputs_by_sample = inputs.permute(1, 0, 2)  # loop-invariant, hoisted out of the loop\n",
    "        outputs = []\n",
    "\n",
    "        for _ in range(n_steps):\n",
    "            # attention on windows\n",
    "            hx, cx = (tensor.squeeze(0) for tensor in hidden)\n",
    "            hx_rep = hx.repeat(n_steps, 1, 1)\n",
    "            cx_rep = cx.repeat(n_steps, 1, 1)\n",
    "            x = torch.cat((inputs, hx_rep, cx_rep), dim=-1)\n",
    "            scores = self.attn_layer(x)\n",
    "            beta = self.softmax(scores)\n",
    "            # Weighted sum of the window per sample: [sample, 1, step] x [sample, step, feature].\n",
    "            context = torch.bmm(beta.permute(1, 2, 0), inputs_by_sample).squeeze(1)\n",
    "            out, hidden = self.cell(context, hidden)\n",
    "            outputs.append(out)\n",
    "        return torch.stack(outputs), hidden"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "class DRNN(nn.Module):\n",
    "    \"\"\"Stack of recurrent layers in which layer `i` is applied with dilation `dilations[i]`.\n",
    "\n",
    "    A dilation of `rate` is realised by splitting the sequence into `rate` interleaved\n",
    "    sub-sequences, stacking them along the batch axis, running the cell once over the\n",
    "    shortened sequence, and interleaving the outputs back into the original step order.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_input, n_hidden, n_layers, dilations, dropout=0, cell_type='GRU', batch_first=True):\n",
    "        super(DRNN, self).__init__()\n",
    "\n",
    "        self.dilations = dilations\n",
    "        self.cell_type = cell_type\n",
    "        self.batch_first = batch_first\n",
    "\n",
    "        if self.cell_type == \"GRU\":\n",
    "            cell = nn.GRU\n",
    "        elif self.cell_type == \"RNN\":\n",
    "            cell = nn.RNN\n",
    "        elif self.cell_type == \"LSTM\":\n",
    "            cell = nn.LSTM\n",
    "        elif self.cell_type == \"ResLSTM\":\n",
    "            cell = ResLSTMLayer\n",
    "        elif self.cell_type == \"AttentiveLSTM\":\n",
    "            cell = AttentiveLSTMLayer\n",
    "        else:\n",
    "            # Same exception type as before, now naming the offending value and the valid ones.\n",
    "            raise NotImplementedError(\n",
    "                f\"Unsupported cell_type '{self.cell_type}'. Supported options: \"\n",
    "                \"'GRU', 'RNN', 'LSTM', 'ResLSTM', 'AttentiveLSTM'.\"\n",
    "            )\n",
    "\n",
    "        # Only the first layer consumes the raw inputs; deeper layers consume hidden states.\n",
    "        layers = [cell(n_input if i == 0 else n_hidden, n_hidden, dropout=dropout)\n",
    "                  for i in range(n_layers)]\n",
    "        self.cells = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, inputs, hidden=None):\n",
    "        \"\"\"Run every layer in turn, feeding each layer's outputs to the next.\n",
    "\n",
    "        Returns `(inputs, outputs)`: the last layer's full output sequence (transposed\n",
    "        back when `batch_first` is set) and a list holding, per layer, the final\n",
    "        `dilation` steps of that layer's output in step-major layout. When `hidden`\n",
    "        is given, `hidden[i]` is overwritten with the state returned by layer `i`.\n",
    "        \"\"\"\n",
    "        if self.batch_first:\n",
    "            inputs = inputs.transpose(0, 1)\n",
    "        outputs = []\n",
    "        # `zip` stops at the shorter of the two, so surplus cells or dilations are ignored.\n",
    "        for i, (cell, dilation) in enumerate(zip(self.cells, self.dilations)):\n",
    "            if hidden is None:\n",
    "                inputs, _ = self.drnn_layer(cell, inputs, dilation)\n",
    "            else:\n",
    "                inputs, hidden[i] = self.drnn_layer(cell, inputs, dilation, hidden[i])\n",
    "\n",
    "            outputs.append(inputs[-dilation:])\n",
    "\n",
    "        if self.batch_first:\n",
    "            inputs = inputs.transpose(0, 1)\n",
    "        return inputs, outputs\n",
    "\n",
    "    def drnn_layer(self, cell, inputs, rate, hidden=None):\n",
    "        \"\"\"Apply one `cell` with dilation `rate` to step-major `inputs`; return `(outputs, hidden)`.\"\"\"\n",
    "        n_steps = len(inputs)\n",
    "        batch_size = inputs[0].size(0)\n",
    "        hidden_size = cell.hidden_size\n",
    "\n",
    "        inputs, dilated_steps = self._pad_inputs(inputs, n_steps, rate)\n",
    "        dilated_inputs = self._prepare_inputs(inputs, rate)\n",
    "\n",
    "        if hidden is None:\n",
    "            dilated_outputs, hidden = self._apply_cell(dilated_inputs, cell, batch_size, rate, hidden_size)\n",
    "        else:\n",
    "            # A supplied state is folded the same way as the inputs before use.\n",
    "            hidden = self._prepare_inputs(hidden, rate)\n",
    "            dilated_outputs, hidden = self._apply_cell(dilated_inputs, cell, batch_size, rate, hidden_size,\n",
    "                                                       hidden=hidden)\n",
    "\n",
    "        splitted_outputs = self._split_outputs(dilated_outputs, rate)\n",
    "        outputs = self._unpad_outputs(splitted_outputs, n_steps)\n",
    "\n",
    "        return outputs, hidden\n",
    "\n",
    "    def _apply_cell(self, dilated_inputs, cell, batch_size, rate, hidden_size, hidden=None):\n",
    "        \"\"\"Call `cell` on the folded inputs, creating a zero initial state when none is given.\"\"\"\n",
    "        if hidden is None:\n",
    "            # The folded batch holds `rate` sub-sequences per original sample.\n",
    "            hidden = torch.zeros(batch_size * rate, hidden_size,\n",
    "                                 dtype=dilated_inputs.dtype,\n",
    "                                 device=dilated_inputs.device)\n",
    "            hidden = hidden.unsqueeze(0)\n",
    "\n",
    "            # LSTM-style cells take an (h, c) pair instead of a single state tensor.\n",
    "            if self.cell_type in ['LSTM', 'ResLSTM', 'AttentiveLSTM']:\n",
    "                hidden = (hidden, hidden)\n",
    "\n",
    "        dilated_outputs, hidden = cell(dilated_inputs, hidden) # compatibility hack\n",
    "\n",
    "        return dilated_outputs, hidden\n",
    "\n",
    "    def _unpad_outputs(self, splitted_outputs, n_steps):\n",
    "        \"\"\"Drop the trailing steps that `_pad_inputs` appended.\"\"\"\n",
    "        return splitted_outputs[:n_steps]\n",
    "\n",
    "    def _split_outputs(self, dilated_outputs, rate):\n",
    "        \"\"\"Undo `_prepare_inputs`: interleave the `rate` batch blocks back along the step axis.\"\"\"\n",
    "        batchsize = dilated_outputs.size(1) // rate\n",
    "\n",
    "        blocks = [dilated_outputs[:, i * batchsize: (i + 1) * batchsize, :] for i in range(rate)]\n",
    "\n",
    "        interleaved = torch.stack((blocks)).transpose(1, 0)\n",
    "        interleaved = interleaved.reshape(dilated_outputs.size(0) * rate,\n",
    "                                       batchsize,\n",
    "                                       dilated_outputs.size(2))\n",
    "        return interleaved\n",
    "\n",
    "    def _pad_inputs(self, inputs, n_steps, rate):\n",
    "        \"\"\"Zero-pad the step axis up to a multiple of `rate`; return `(inputs, steps_per_subsequence)`.\"\"\"\n",
    "        if n_steps % rate == 0:\n",
    "            return inputs, n_steps // rate\n",
    "\n",
    "        dilated_steps = n_steps // rate + 1\n",
    "        zeros_ = torch.zeros(dilated_steps * rate - inputs.size(0),\n",
    "                             inputs.size(1),\n",
    "                             inputs.size(2),\n",
    "                             dtype=inputs.dtype,\n",
    "                             device=inputs.device)\n",
    "        return torch.cat((inputs, zeros_)), dilated_steps\n",
    "\n",
    "    def _prepare_inputs(self, inputs, rate):\n",
    "        \"\"\"Fold the step axis: sub-sequence `j` (steps j, j+rate, ...) becomes batch block `j`.\"\"\"\n",
    "        dilated_inputs = torch.cat([inputs[j::rate, :, :] for j in range(rate)], 1)\n",
    "        return dilated_inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class DilatedRNN(BaseModel):\n",
    "    \"\"\" DilatedRNN\n",
    "\n",
    "    Encodes the input window with a stack of dilated recurrent layer groups, maps the\n",
    "    encoded window onto the forecast horizon with a linear layer, and decodes every\n",
    "    horizon step (together with its future exogenous inputs) with an MLP.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `h`: int, forecast horizon.<br>\n",
    "    `input_size`: int, maximum sequence length for truncated train backpropagation. Default -1 uses 3 * horizon <br>\n",
    "    `inference_input_size`: int, maximum sequence length for truncated inference. Default None uses input_size history.<br>\n",
    "    `cell_type`: str, type of RNN cell to use. Options: 'GRU', 'RNN', 'LSTM', 'ResLSTM', 'AttentiveLSTM'.<br>\n",
    "    `dilations`: list of int lists, dilations between layers; each inner list defines one group of layers.<br>\n",
    "    `encoder_hidden_size`: int=128, units for the RNN's hidden state size.<br>\n",
    "    `context_size`: int=10, size of context vector for each timestamp on the forecasting window (stored on the model; not read by the layers built in this class).<br>\n",
    "    `decoder_hidden_size`: int=128, size of hidden layer for the MLP decoder.<br>\n",
    "    `decoder_layers`: int=2, number of layers for the MLP decoder.<br>\n",
    "    `futr_exog_list`: str list, future exogenous columns.<br>\n",
    "    `hist_exog_list`: str list, historic exogenous columns.<br>\n",
    "    `stat_exog_list`: str list, static exogenous columns.<br>\n",
    "    `exclude_insample_y`: bool=False, the model skips the autoregressive features y[t-input_size:t] if True.<br>\n",
    "    `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `max_steps`: int, maximum number of training steps.<br>\n",
    "    `learning_rate`: float, Learning rate between (0, 1).<br>\n",
    "    `num_lr_decays`: int, Number of learning rate decays, evenly distributed across max_steps.<br>\n",
    "    `early_stop_patience_steps`: int, Number of validation iterations before early stopping.<br>\n",
    "    `val_check_steps`: int, Number of training steps between every validation loss check.<br>\n",
    "    `batch_size`: int=32, number of different series in each batch.<br>\n",
    "    `valid_batch_size`: int=None, number of different series in each validation and test batch.<br>\n",
    "    `windows_batch_size`: int=128, number of windows to sample in each training batch, default uses all.<br>\n",
    "    `inference_windows_batch_size`: int=1024, number of windows to sample in each inference batch, -1 uses all.<br>\n",
    "    `start_padding_enabled`: bool=False, if True, the model will pad the time series with zeros at the beginning, by input size.<br>\n",
    "    `step_size`: int=1, step size between each window of temporal data.<br>\n",
    "    `scaler_type`: str='robust', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>\n",
    "    `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.<br>\n",
    "    `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>\n",
    "    `alias`: str, optional,  Custom name of the model.<br>\n",
    "    `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).<br>\n",
    "    `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
    "    `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
    "    `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
    "    `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
    "    `**trainer_kwargs`: int,  keyword trainer arguments inherited from [PyTorch Lightning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
    "    \"\"\"\n",
    "    # Class attributes\n",
    "    EXOGENOUS_FUTR = True\n",
    "    EXOGENOUS_HIST = True\n",
    "    EXOGENOUS_STAT = True\n",
    "    MULTIVARIATE = False    # If the model produces multivariate forecasts (True) or univariate (False)\n",
    "    RECURRENT = False       # If the model produces forecasts recursively (True) or direct (False)\n",
    "\n",
    "    def __init__(self,\n",
    "                 h: int,\n",
    "                 input_size: int = -1,\n",
    "                 inference_input_size: Optional[int] = None,\n",
    "                 cell_type: str = 'LSTM',\n",
    "                 dilations: List[List[int]] = [[1, 2], [4, 8]],\n",
    "                 encoder_hidden_size: int = 128,\n",
    "                 context_size: int = 10,\n",
    "                 decoder_hidden_size: int = 128,\n",
    "                 decoder_layers: int = 2,\n",
    "                 futr_exog_list = None,\n",
    "                 hist_exog_list = None,\n",
    "                 stat_exog_list = None,\n",
    "                 exclude_insample_y = False,\n",
    "                 loss = MAE(),\n",
    "                 valid_loss = None,\n",
    "                 max_steps: int = 1000,\n",
    "                 learning_rate: float = 1e-3,\n",
    "                 num_lr_decays: int = 3,\n",
    "                 early_stop_patience_steps: int =-1,\n",
    "                 val_check_steps: int = 100,\n",
    "                 batch_size = 32,\n",
    "                 valid_batch_size: Optional[int] = None,\n",
    "                 windows_batch_size = 128,\n",
    "                 inference_windows_batch_size = 1024,\n",
    "                 start_padding_enabled = False,\n",
    "                 step_size: int = 1,\n",
    "                 scaler_type: str = 'robust',\n",
    "                 random_seed: int = 1,\n",
    "                 drop_last_loader: bool = False,\n",
    "                 alias: Optional[str] = None,\n",
    "                 optimizer = None,\n",
    "                 optimizer_kwargs = None,\n",
    "                 lr_scheduler = None,\n",
    "                 lr_scheduler_kwargs = None,\n",
    "                 dataloader_kwargs = None,\n",
    "                 **trainer_kwargs):\n",
    "        super(DilatedRNN, self).__init__(\n",
    "            h=h,\n",
    "            input_size=input_size,\n",
    "            inference_input_size=inference_input_size,\n",
    "            futr_exog_list=futr_exog_list,\n",
    "            hist_exog_list=hist_exog_list,\n",
    "            stat_exog_list=stat_exog_list,\n",
    "            exclude_insample_y = exclude_insample_y,\n",
    "            loss=loss,\n",
    "            valid_loss=valid_loss,\n",
    "            max_steps=max_steps,\n",
    "            learning_rate=learning_rate,\n",
    "            num_lr_decays=num_lr_decays,\n",
    "            early_stop_patience_steps=early_stop_patience_steps,\n",
    "            val_check_steps=val_check_steps,\n",
    "            batch_size=batch_size,\n",
    "            valid_batch_size=valid_batch_size,\n",
    "            windows_batch_size=windows_batch_size,\n",
    "            inference_windows_batch_size=inference_windows_batch_size,\n",
    "            start_padding_enabled=start_padding_enabled,\n",
    "            step_size=step_size,\n",
    "            scaler_type=scaler_type,\n",
    "            random_seed=random_seed,\n",
    "            drop_last_loader=drop_last_loader,\n",
    "            alias=alias,\n",
    "            optimizer=optimizer,\n",
    "            optimizer_kwargs=optimizer_kwargs,\n",
    "            lr_scheduler=lr_scheduler,\n",
    "            lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
    "            dataloader_kwargs=dataloader_kwargs,\n",
    "            **trainer_kwargs\n",
    "        )\n",
    "\n",
    "        # Dilated RNN\n",
    "        self.cell_type = cell_type\n",
    "        # The default list object is shared between instances; it is only read here, never mutated.\n",
    "        self.dilations = dilations\n",
    "        self.encoder_hidden_size = encoder_hidden_size\n",
    "\n",
    "        # Context adapter\n",
    "        # NOTE: `context_size` is only stored; none of the layers built below read it.\n",
    "        self.context_size = context_size\n",
    "\n",
    "        # MLP decoder\n",
    "        self.decoder_hidden_size = decoder_hidden_size\n",
    "        self.decoder_layers = decoder_layers\n",
    "\n",
    "        # RNN input size (1 for target variable y)\n",
    "        input_encoder = 1 + self.hist_exog_size + self.stat_exog_size + self.futr_exog_size\n",
    "\n",
    "        # Instantiate model: one DRNN per dilation group; groups after the first consume hidden states\n",
    "        layers = []\n",
    "        for grp_num in range(len(self.dilations)):\n",
    "            if grp_num > 0:\n",
    "                input_encoder = self.encoder_hidden_size\n",
    "            layer = DRNN(input_encoder,\n",
    "                         self.encoder_hidden_size,\n",
    "                         n_layers=len(self.dilations[grp_num]),\n",
    "                         dilations=self.dilations[grp_num],\n",
    "                         cell_type=self.cell_type)\n",
    "            layers.append(layer)\n",
    "\n",
    "        self.rnn_stack = nn.Sequential(*layers)\n",
    "\n",
    "        # Context adapter: linear map along the time axis, from the input window length to the horizon\n",
    "        self.context_adapter = nn.Linear(in_features=self.input_size,\n",
    "                                         out_features=h)\n",
    "\n",
    "        # Decoder MLP\n",
    "        self.mlp_decoder = MLP(in_features=self.encoder_hidden_size + self.futr_exog_size,\n",
    "                               out_features=self.loss.outputsize_multiplier,\n",
    "                               hidden_size=self.decoder_hidden_size,\n",
    "                               num_layers=self.decoder_layers,\n",
    "                               activation='ReLU',\n",
    "                               dropout=0.0)\n",
    "\n",
    "    def forward(self, windows_batch):\n",
    "        \"\"\"Compute forecasts for a batch of windows.\n",
    "\n",
    "        Reads the `insample_y`, `futr_exog`, `hist_exog` and `stat_exog` entries of\n",
    "        `windows_batch` and returns the output of the MLP decoder.\n",
    "        \"\"\"\n",
    "        # Parse windows_batch\n",
    "        encoder_input = windows_batch['insample_y']                         # [B, L, 1]\n",
    "        futr_exog     = windows_batch['futr_exog']                          # [B, L + h, F]\n",
    "        hist_exog     = windows_batch['hist_exog']                          # [B, L, X]\n",
    "        stat_exog     = windows_batch['stat_exog']                          # [B, S]\n",
    "\n",
    "        # Concatenate y, historic and static inputs\n",
    "        seq_len = encoder_input.shape[1]\n",
    "        if self.hist_exog_size > 0:\n",
    "            encoder_input = torch.cat((encoder_input, hist_exog), dim=2)    # [B, L, 1] + [B, L, X] -> [B, L, 1 + X]\n",
    "\n",
    "        if self.stat_exog_size > 0:\n",
    "            stat_exog = stat_exog.unsqueeze(1).repeat(1, seq_len, 1)        # [B, S] -> [B, L, S]\n",
    "            encoder_input = torch.cat((encoder_input, stat_exog), dim=2)    # [B, L, 1 + X] + [B, L, S] -> [B, L, 1 + X + S]\n",
    "\n",
    "        if self.futr_exog_size > 0:\n",
    "            encoder_input = torch.cat((encoder_input,\n",
    "                                       futr_exog[:, :seq_len]), dim=2)      # [B, L, 1 + X + S] + [B, L, F] -> [B, L, 1 + X + S + F]\n",
    "\n",
    "        # DilatedRNN forward\n",
    "        for layer_num in range(len(self.rnn_stack)):\n",
    "            residual = encoder_input\n",
    "            output, _ = self.rnn_stack[layer_num](encoder_input)\n",
    "            if layer_num > 0:\n",
    "                # Out-of-place add: `output` is a transposed view of the layer's result,\n",
    "                # so it is not modified in place.\n",
    "                output = output + residual\n",
    "            encoder_input = output\n",
    "\n",
    "        # Context adapter\n",
    "        output = output.permute(0, 2, 1)                                    # [B, L, C] -> [B, C, L]\n",
    "        context = self.context_adapter(output)                              # [B, C, L] -> [B, C, h]\n",
    "\n",
    "        # Residual connection with futr_exog\n",
    "        if self.futr_exog_size > 0:\n",
    "            futr_exog_futr = futr_exog[:, seq_len:].permute(0, 2, 1)        # [B, h, F] -> [B, F, h]\n",
    "            context = torch.cat((context, futr_exog_futr),\n",
    "                                dim=1)                                      # [B, C, h] + [B, F, h] = [B, C + F, h]\n",
    "\n",
    "        # Final forecast\n",
    "        context = context.permute(0, 2, 1)                                  # [B, C + F, h] -> [B, h, C + F]\n",
    "        output = self.mlp_decoder(context)                                  # [B, h, C + F] -> [B, h, n_output]\n",
    "\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/dilated_rnn.py#L289){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### DilatedRNN\n",
       "\n",
       ">      DilatedRNN (h:int, input_size:int, inference_input_size:int=-1,\n",
       ">                  cell_type:str='LSTM', dilations:List[List[int]]=[[1, 2], [4,\n",
       ">                  8]], encoder_hidden_size:int=200, context_size:int=10,\n",
       ">                  decoder_hidden_size:int=200, decoder_layers:int=2,\n",
       ">                  futr_exog_list=None, hist_exog_list=None,\n",
       ">                  stat_exog_list=None, exclude_insample_y=False, loss=MAE(),\n",
       ">                  valid_loss=None, max_steps:int=1000,\n",
       ">                  learning_rate:float=0.001, num_lr_decays:int=3,\n",
       ">                  early_stop_patience_steps:int=-1, val_check_steps:int=100,\n",
       ">                  batch_size=32, valid_batch_size:Optional[int]=None,\n",
       ">                  windows_batch_size=1024, inference_windows_batch_size=1024,\n",
       ">                  start_padding_enabled=False, step_size:int=1,\n",
       ">                  scaler_type:str='robust', random_seed:int=1,\n",
       ">                  num_workers_loader:int=0, drop_last_loader:bool=False,\n",
       ">                  optimizer=None, optimizer_kwargs=None, lr_scheduler=None,\n",
       ">                  lr_scheduler_kwargs=None, **trainer_kwargs)\n",
       "\n",
       "*DilatedRNN\n",
       "\n",
       "**Parameters:**<br>\n",
       "`h`: int, forecast horizon.<br>\n",
       "`input_size`: int, maximum sequence length for truncated train backpropagation. Default -1 uses all history.<br>\n",
       "`inference_input_size`: int, maximum sequence length for truncated inference. Default -1 uses all history.<br>\n",
       "`cell_type`: str, type of RNN cell to use. Options: 'GRU', 'RNN', 'LSTM', 'ResLSTM', 'AttentiveLSTM'.<br>\n",
       "`dilations`: int list, dilations betweem layers.<br>\n",
       "`encoder_hidden_size`: int=200, units for the RNN's hidden state size.<br>\n",
       "`context_size`: int=10, size of context vector for each timestamp on the forecasting window.<br>\n",
       "`decoder_hidden_size`: int=200, size of hidden layer for the MLP decoder.<br>\n",
       "`decoder_layers`: int=2, number of layers for the MLP decoder.<br>\n",
       "`futr_exog_list`: str list, future exogenous columns.<br>\n",
       "`hist_exog_list`: str list, historic exogenous columns.<br>\n",
       "`stat_exog_list`: str list, static exogenous columns.<br>\n",
       "`loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
       "`valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
       "`max_steps`: int, maximum number of training steps.<br>\n",
       "`learning_rate`: float, Learning rate between (0, 1).<br>\n",
       "`num_lr_decays`: int, Number of learning rate decays, evenly distributed across max_steps.<br>\n",
       "`early_stop_patience_steps`: int, Number of validation iterations before early stopping.<br>\n",
       "`val_check_steps`: int, Number of training steps between every validation loss check.<br>\n",
       "`batch_size`: int=32, number of different series in each batch.<br>\n",
       "`valid_batch_size`: int=None, number of different series in each validation and test batch.<br>\n",
       "`step_size`: int=1, step size between each window of temporal data.<br>\n",
       "`scaler_type`: str='robust', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>\n",
       "`random_seed`: int=1, random_seed for pytorch initializer and numpy generators.<br>\n",
       "`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.<br>\n",
       "`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>\n",
       "`alias`: str, optional,  Custom name of the model.<br>\n",
       "`optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).<br>\n",
       "`optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
       "`lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
       "`lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br> \n",
       "`**trainer_kwargs`: int,  keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>*"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/dilated_rnn.py#L289){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### DilatedRNN\n",
       "\n",
       ">      DilatedRNN (h:int, input_size:int, inference_input_size:int=-1,\n",
       ">                  cell_type:str='LSTM', dilations:List[List[int]]=[[1, 2], [4,\n",
       ">                  8]], encoder_hidden_size:int=200, context_size:int=10,\n",
       ">                  decoder_hidden_size:int=200, decoder_layers:int=2,\n",
       ">                  futr_exog_list=None, hist_exog_list=None,\n",
       ">                  stat_exog_list=None, exclude_insample_y=False, loss=MAE(),\n",
       ">                  valid_loss=None, max_steps:int=1000,\n",
       ">                  learning_rate:float=0.001, num_lr_decays:int=3,\n",
       ">                  early_stop_patience_steps:int=-1, val_check_steps:int=100,\n",
       ">                  batch_size=32, valid_batch_size:Optional[int]=None,\n",
       ">                  windows_batch_size=1024, inference_windows_batch_size=1024,\n",
       ">                  start_padding_enabled=False, step_size:int=1,\n",
       ">                  scaler_type:str='robust', random_seed:int=1,\n",
       ">                  num_workers_loader:int=0, drop_last_loader:bool=False,\n",
       ">                  optimizer=None, optimizer_kwargs=None, lr_scheduler=None,\n",
       ">                  lr_scheduler_kwargs=None, **trainer_kwargs)\n",
       "\n",
       "*DilatedRNN\n",
       "\n",
       "**Parameters:**<br>\n",
       "`h`: int, forecast horizon.<br>\n",
       "`input_size`: int, maximum sequence length for truncated train backpropagation. Default -1 uses all history.<br>\n",
       "`inference_input_size`: int, maximum sequence length for truncated inference. Default -1 uses all history.<br>\n",
       "`cell_type`: str, type of RNN cell to use. Options: 'GRU', 'RNN', 'LSTM', 'ResLSTM', 'AttentiveLSTM'.<br>\n",
       "`dilations`: int list, dilations betweem layers.<br>\n",
       "`encoder_hidden_size`: int=200, units for the RNN's hidden state size.<br>\n",
       "`context_size`: int=10, size of context vector for each timestamp on the forecasting window.<br>\n",
       "`decoder_hidden_size`: int=200, size of hidden layer for the MLP decoder.<br>\n",
       "`decoder_layers`: int=2, number of layers for the MLP decoder.<br>\n",
       "`futr_exog_list`: str list, future exogenous columns.<br>\n",
       "`hist_exog_list`: str list, historic exogenous columns.<br>\n",
       "`stat_exog_list`: str list, static exogenous columns.<br>\n",
       "`loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
       "`valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
       "`max_steps`: int, maximum number of training steps.<br>\n",
       "`learning_rate`: float, Learning rate between (0, 1).<br>\n",
       "`num_lr_decays`: int, Number of learning rate decays, evenly distributed across max_steps.<br>\n",
       "`early_stop_patience_steps`: int, Number of validation iterations before early stopping.<br>\n",
       "`val_check_steps`: int, Number of training steps between every validation loss check.<br>\n",
       "`batch_size`: int=32, number of different series in each batch.<br>\n",
       "`valid_batch_size`: int=None, number of different series in each validation and test batch.<br>\n",
       "`step_size`: int=1, step size between each window of temporal data.<br>\n",
       "`scaler_type`: str='robust', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>\n",
       "`random_seed`: int=1, random_seed for pytorch initializer and numpy generators.<br>\n",
       "`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.<br>\n",
       "`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>\n",
       "`alias`: str, optional,  Custom name of the model.<br>\n",
       "`optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).<br>\n",
       "`optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
       "`lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
       "`lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br> \n",
       "`**trainer_kwargs`: int,  keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>*"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(DilatedRNN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "### DilatedRNN.fit\n",
       "\n",
       ">      DilatedRNN.fit (dataset, val_size=0, test_size=0, random_seed=None,\n",
       ">                      distributed_config=None)\n",
       "\n",
       "*Fit.\n",
       "\n",
       "The `fit` method, optimizes the neural network's weights using the\n",
       "initialization parameters (`learning_rate`, `windows_batch_size`, ...)\n",
       "and the `loss` function as defined during the initialization.\n",
       "Within `fit` we use a PyTorch Lightning `Trainer` that\n",
       "inherits the initialization's `self.trainer_kwargs`, to customize\n",
       "its inputs, see [PL's trainer arguments](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).\n",
       "\n",
       "The method is designed to be compatible with SKLearn-like classes\n",
       "and in particular to be compatible with the StatsForecast library.\n",
       "\n",
       "By default the `model` is not saving training checkpoints to protect\n",
       "disk memory, to get them change `enable_checkpointing=True` in `__init__`.\n",
       "\n",
       "**Parameters:**<br>\n",
       "`dataset`: NeuralForecast's `TimeSeriesDataset`, see [documentation](https://nixtla.github.io/neuralforecast/tsdataset.html).<br>\n",
       "`val_size`: int, validation size for temporal cross-validation.<br>\n",
       "`random_seed`: int=None, random_seed for pytorch initializer and numpy generators, overwrites model.__init__'s.<br>\n",
       "`test_size`: int, test size for temporal cross-validation.<br>*"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "### DilatedRNN.fit\n",
       "\n",
       ">      DilatedRNN.fit (dataset, val_size=0, test_size=0, random_seed=None,\n",
       ">                      distributed_config=None)\n",
       "\n",
       "*Fit.\n",
       "\n",
       "The `fit` method, optimizes the neural network's weights using the\n",
       "initialization parameters (`learning_rate`, `windows_batch_size`, ...)\n",
       "and the `loss` function as defined during the initialization.\n",
       "Within `fit` we use a PyTorch Lightning `Trainer` that\n",
       "inherits the initialization's `self.trainer_kwargs`, to customize\n",
       "its inputs, see [PL's trainer arguments](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).\n",
       "\n",
       "The method is designed to be compatible with SKLearn-like classes\n",
       "and in particular to be compatible with the StatsForecast library.\n",
       "\n",
       "By default the `model` is not saving training checkpoints to protect\n",
       "disk memory, to get them change `enable_checkpointing=True` in `__init__`.\n",
       "\n",
       "**Parameters:**<br>\n",
       "`dataset`: NeuralForecast's `TimeSeriesDataset`, see [documentation](https://nixtla.github.io/neuralforecast/tsdataset.html).<br>\n",
       "`val_size`: int, validation size for temporal cross-validation.<br>\n",
       "`random_seed`: int=None, random_seed for pytorch initializer and numpy generators, overwrites model.__init__'s.<br>\n",
       "`test_size`: int, test size for temporal cross-validation.<br>*"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(DilatedRNN.fit, name='DilatedRNN.fit')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "### DilatedRNN.predict\n",
       "\n",
       ">      DilatedRNN.predict (dataset, test_size=None, step_size=1,\n",
       ">                          random_seed=None, **data_module_kwargs)\n",
       "\n",
       "*Predict.\n",
       "\n",
       "Neural network prediction with PL's `Trainer` execution of `predict_step`.\n",
       "\n",
       "**Parameters:**<br>\n",
       "`dataset`: NeuralForecast's `TimeSeriesDataset`, see [documentation](https://nixtla.github.io/neuralforecast/tsdataset.html).<br>\n",
       "`test_size`: int=None, test size for temporal cross-validation.<br>\n",
       "`step_size`: int=1, Step size between each window.<br>\n",
       "`random_seed`: int=None, random_seed for pytorch initializer and numpy generators, overwrites model.__init__'s.<br>\n",
       "`**data_module_kwargs`: PL's TimeSeriesDataModule args, see [documentation](https://pytorch-lightning.readthedocs.io/en/1.6.1/extensions/datamodules.html#using-a-datamodule).*"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "### DilatedRNN.predict\n",
       "\n",
       ">      DilatedRNN.predict (dataset, test_size=None, step_size=1,\n",
       ">                          random_seed=None, **data_module_kwargs)\n",
       "\n",
       "*Predict.\n",
       "\n",
       "Neural network prediction with PL's `Trainer` execution of `predict_step`.\n",
       "\n",
       "**Parameters:**<br>\n",
       "`dataset`: NeuralForecast's `TimeSeriesDataset`, see [documentation](https://nixtla.github.io/neuralforecast/tsdataset.html).<br>\n",
       "`test_size`: int=None, test size for temporal cross-validation.<br>\n",
       "`step_size`: int=1, Step size between each window.<br>\n",
       "`random_seed`: int=None, random_seed for pytorch initializer and numpy generators, overwrites model.__init__'s.<br>\n",
       "`**data_module_kwargs`: PL's TimeSeriesDataModule args, see [documentation](https://pytorch-lightning.readthedocs.io/en/1.6.1/extensions/datamodules.html#using-a-datamodule).*"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(DilatedRNN.predict, name='DilatedRNN.predict')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DilatedRNN: checking forecast AirPassengers dataset\n"
     ]
    }
   ],
   "source": [
    "#| hide\n",
    "# Unit tests for models\n",
    "logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR)\n",
    "logging.getLogger(\"lightning_fabric\").setLevel(logging.ERROR)\n",
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter(\"ignore\")\n",
    "    check_model(DilatedRNN, [\"airpassengers\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Usage Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from neuralforecast import NeuralForecast\n",
    "from neuralforecast.models import DilatedRNN\n",
    "from neuralforecast.losses.pytorch import DistributionLoss\n",
    "from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic\n",
    "\n",
    "Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]] # 132 train\n",
    "Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n",
    "\n",
    "fcst = NeuralForecast(\n",
    "    models=[DilatedRNN(h=12,\n",
    "                       input_size=-1,\n",
    "                       loss=DistributionLoss(distribution='Normal', level=[80, 90]),\n",
    "                       scaler_type='robust',\n",
    "                       encoder_hidden_size=100,\n",
    "                       max_steps=200,\n",
    "                       futr_exog_list=['y_[lag12]'],\n",
    "                       hist_exog_list=None,\n",
    "                       stat_exog_list=['airline1'],\n",
    "    )\n",
    "    ],\n",
    "    freq='ME'\n",
    ")\n",
    "fcst.fit(df=Y_train_df, static_df=AirPassengersStatic)\n",
    "forecasts = fcst.predict(futr_df=Y_test_df)\n",
    "\n",
    "Y_hat_df = forecasts.reset_index(drop=False).drop(columns=['unique_id','ds'])\n",
    "plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)\n",
    "plot_df = pd.concat([Y_train_df, plot_df])\n",
    "\n",
    "plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)\n",
    "plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')\n",
    "plt.plot(plot_df['ds'], plot_df['DilatedRNN-median'], c='blue', label='median')\n",
    "plt.fill_between(x=plot_df['ds'][-12:], \n",
    "                 y1=plot_df['DilatedRNN-lo-90'][-12:].values, \n",
    "                 y2=plot_df['DilatedRNN-hi-90'][-12:].values,\n",
    "                 alpha=0.4, label='level 90')\n",
    "plt.legend()\n",
    "plt.grid()\n",
    "plt.plot()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
